""" DiffBC Policies Implementation """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.norm_layers import BaseNormLayer
from sb3_jax.common.jax_layers import BaseFeaturesExtractor, FlattenExtractor
from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.utils import get_dummy_decision_transformer, get_dummy_obs, get_dummy_act, print_b

from sb3_jax.du.policies import DiffusionBetaScheduler
from diffgro.common.models.diffusion import MLPDiffusion, Diffusion


class Actor(BasePolicy):
    """ diffusion actor for the diffbc policy

    Holds a haiku-transformed ``Diffusion`` module (wrapping an
    ``MLPDiffusion`` network) conditioned on a ``"task"`` and an ``"obs"``
    entry, together with its parameters in ``self.params``.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: List[int],
        activation_fn: str = 'mish',
        # embedding
        sem_dim: int = 512,
        emb_dim: int = 64,
        # diffusion
        n_denoise: int = 20,
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'cosine',
        seed: int = 1,
    ):
        """
        :param observation_space: observation space, also used to size the dummy observation
        :param action_space: action space, its dimension is the network output dimension
        :param net_arch: layer sizes forwarded to ``MLPDiffusion``
        :param activation_fn: activation name forwarded to ``MLPDiffusion``
        :param sem_dim: width of the dummy task vector used to initialise the network
        :param emb_dim: embedding dimension forwarded to ``MLPDiffusion``
        :param n_denoise: number of denoising steps
        :param predict_epsilon: forwarded to ``Diffusion`` (predict noise / original)
        :param beta_scheduler: name of the beta schedule given to ``DiffusionBetaScheduler``
        :param seed: seed forwarded to the base policy
        """
        super().__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        # network hyper-parameters
        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.sem_dim = sem_dim
        self.emb_dim = emb_dim

        # diffusion hyper-parameters and the schedule derived from them
        self.n_denoise = n_denoise
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler
        scheduler = DiffusionBetaScheduler(None, None, n_denoise, beta_scheduler)
        self.ddpm_dict = scheduler.schedule()

        # input / output sizes
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)
        self.out_dim = self.act_dim

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ return the constructor parameters reported by the base class, unchanged """
        return super()._get_constructor_parameters()

    def _build_act(self, batch_keys: Dict[str, jax.Array]) -> hk.Module:
        """ create the ``Diffusion`` module; must be called inside ``hk.transform``

        :param batch_keys: keys of the conditioning batch forwarded to
            ``MLPDiffusion`` (``_build`` passes the list ``["task", "obs"]``)
        :return: ``Diffusion`` module using ddpm denoising and guidance weight 1.0
        """
        denoiser = MLPDiffusion(
            emb_dim=self.emb_dim,
            out_dim=self.out_dim,
            net_arch=self.net_arch,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn
        )
        diffusion_kwargs = dict(
            diffusion=denoiser,
            n_denoise=self.n_denoise,
            ddpm_dict=self.ddpm_dict,
            guidance_weight=1.0,    # no guidance
            predict_epsilon=self.predict_epsilon,   # predict noise
            denoise_type='ddpm',
        )
        return Diffusion(**diffusion_kwargs)

    def _build(self) -> None:
        """ transform the actor with haiku and initialise ``self.params`` from dummy inputs """
        # dummy inputs (the task draw consumes the first rng key)
        task_sample = jax.random.normal(next(self.rng), shape=(1, self.sem_dim))
        obs_sample = get_dummy_obs(self.observation_space)
        act_sample = get_dummy_act(self.action_space)
        t_sample = jnp.array([[1.]])

        def forward(x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array, denoise: bool, deterministic: bool):
            module = self._build_act(batch_keys=["task", "obs"])
            return module(x_t, batch_dict, t, denoise, deterministic)

        transformed = hk.transform(forward)
        self.pi = transformed.apply
        self.params = transformed.init(
            next(self.rng),
            act_sample,
            {"task": task_sample, "obs": obs_sample},
            t_sample,
            denoise=False,
            deterministic=False,
        )

    @partial(jax.jit, static_argnums=(0, 4, 5))
    def _pi(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        denoise: bool,
        deterministic: bool,
        params: hk.Params,
        rng=None
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ jitted application of the transformed actor

        ``self``, ``denoise`` and ``deterministic`` are static arguments of the jit.
        """
        return self.pi(params, rng, x_t, batch_dict, t, denoise, deterministic)

    def _predict(
        self,
        x_t: jax.Array,
        task: jax.Array,
        obs: jax.Array,
        t: int,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ apply the actor once to ``x_t`` at timestep ``t`` with ``denoise=False``

        :param x_t: input sample handed to the diffusion module
        :param task: task conditioning
        :param obs: observation conditioning; its leading dimension sizes the timestep array
        :param t: timestep, broadcast to an array of shape (obs.shape[0], 1)
        :param deterministic: forwarded to the diffusion module
        """
        n_rows = obs.shape[0]
        timesteps = jnp.full((n_rows, 1), t)
        condition = dict(task=task, obs=obs)
        rng_key = next(self.rng)
        return self._pi(x_t, condition, timesteps, False, deterministic, self.params, rng_key)

    def _denoise(
        self,
        task: jax.Array,
        obs: jax.Array,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ apply the actor with ``denoise=True``; no ``x_t`` or timestep is supplied

        :param task: task conditioning
        :param obs: observation conditioning
        :param deterministic: forwarded to the diffusion module
        """
        condition = dict(task=task, obs=obs)
        rng_key = next(self.rng)
        return self._pi(None, condition, None, True, deterministic, self.params, rng_key)

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """ replace the actor parameters with ``params["pi_params"]`` """
        print_b("[diffbc/actor]: loading params")
        self.params = params["pi_params"]
    
class DiffBCPlannerPolicy(BasePolicy):
    """ policy class for diffbc

    Owns a single diffusion ``Actor`` (``self.act``) together with its
    optimizer (``self.act.optim``) and optimizer state (``self.act.optim_state``).
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Union[float, Callable],
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Union[str, Callable[[jax.Array], jax.Array]] = 'mish',
        sem_dim: int = 512,     # semantic dimension
        emb_dim: int = 64,      # embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        """
        :param observation_space: observation space
        :param action_space: action space
        :param lr_schedule: value passed as ``learning_rate`` to ``optimizer_class``
        :param net_arch: actor layer sizes, either as a plain list or as a
            dict ``{"act": [...]}``; defaults to ``{"act": [128, 128]}``
        :param activation_fn: activation forwarded to the actor
        :param sem_dim: semantic (task) dimension forwarded to the actor
        :param emb_dim: embedding dimension forwarded to the actor
        :param n_denoise: number of denoising steps forwarded to the actor
        :param predict_epsilon: forwarded to the actor (predict noise / original)
        :param beta_scheduler: beta schedule name forwarded to the actor
        :param squash_output: forwarded to the base policy
        :param features_extractor_class: forwarded to the base policy
        :param features_extractor_kwargs: forwarded to the base policy
        :param normalize_images: accepted but not used in this class
        :param optimizer_class: factory called with ``learning_rate`` and ``optimizer_kwargs``
        :param optimizer_kwargs: extra keyword arguments for ``optimizer_class``
        :param normalization_class: optional normalization layer class
        :param normalization_kwargs: keyword arguments for ``normalization_class``
        :param seed: seed forwarded to the base policy and to the actor
        """
        super(DiffBCPlannerPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )

        if net_arch is None:
            net_arch = dict(act=[128, 128])
        elif not isinstance(net_arch, dict):
            # a bare sequence of layer sizes is taken as the actor architecture
            net_arch = dict(act=list(net_arch))
        self.act_arch = net_arch['act']
        self.activation_fn = activation_fn

        self.sem_dim = sem_dim
        self.emb_dim = emb_dim
        self.n_denoise = n_denoise
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler

        # construct args
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "activation_fn": self.activation_fn,
            "seed": seed,
        }

        # actor kwargs
        self.act_kwargs = self.net_args.copy()
        self.act_kwargs.update({
            "net_arch": self.act_arch,
            "sem_dim": sem_dim,
            "emb_dim": emb_dim,
            "n_denoise": n_denoise,
            "predict_epsilon": predict_epsilon,
            "beta_scheduler": beta_scheduler,
        })

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ return the base-class constructor parameters extended with this policy's own

        ``net_arch`` and ``activation_fn`` are included so that a policy
        re-created from these parameters keeps the same actor architecture.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space,
                net_arch=dict(act=self.act_arch),
                activation_fn=self.activation_fn,
                sem_dim=self.sem_dim,
                emb_dim=self.emb_dim,
                n_denoise=self.n_denoise,
                predict_epsilon=self.predict_epsilon,
                beta_scheduler=self.beta_scheduler,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data

    def _build(self, lr_schedule: Union[float, Callable]) -> None:
        """ create the optional normalization layer, the actor and its optimizer

        :param lr_schedule: value passed as ``learning_rate`` to ``optimizer_class``
        """
        # the kwargs default to None in the constructor; fall back to no extra arguments
        normalization_kwargs = self.normalization_kwargs or {}
        optimizer_kwargs = self.optimizer_kwargs or {}

        if self.normalization_class is not None:
            self.normalization_layer = self.normalization_class(self.observation_space.shape, **normalization_kwargs)

        # make actor
        self.act = self.make_act()
        self.act.optim = self.optimizer_class(learning_rate=lr_schedule, **optimizer_kwargs)
        self.act.optim_state = self.act.optim.init(self.act.params)

    def make_act(self) -> Actor:
        """ build a new ``Actor`` from ``self.act_kwargs`` """
        return Actor(**self.act_kwargs)

    def _predict(
        self,
        task: jax.Array,
        obs: jax.Array,
        deterministic: bool = False
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ preprocess ``obs`` (``training=False``) and run the actor's denoising pass

        :param task: task conditioning passed to the actor unchanged
        :param obs: raw observation, preprocessed before it reaches the actor
        :param deterministic: forwarded to the actor
        :return: whatever ``Actor._denoise`` returns
        """
        obs = self.preprocess(obs, training=False)
        return self.act._denoise(task, obs, deterministic)
